﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data.Metadata.Edm;
using System.Data.Objects;
using System.Data;
using Commons.Entity;

namespace Commons.Data
{
    /// <summary>
    /// Base class for repositories that build raw INSERT / UPDATE / DELETE SQL text for the
    /// entities of an <see cref="ObjectContext"/>.
    /// </summary>
    /// <remarks>
    /// The generated SQL assumes that the table name equals the CLR type name of the entity
    /// (<c>entity.GetType().Name</c>) and that every column name equals the CLR property name.
    /// Identifiers are wrapped in square brackets. Values are embedded as escaped literals,
    /// formatted with the invariant culture, so the output does not depend on the thread culture.
    /// NOTE(review): embedding literals is still weaker than parameterized commands; prefer
    /// parameters if the caller is able to execute them.
    /// NOTE(review): string literals are emitted without an N prefix; confirm this is fine for the
    /// target database/collation when storing non-ASCII text.
    /// </remarks>
    public abstract class RepositoryBase
    {
        /// <summary>Culture used for every number/date rendered into SQL text.</summary>
        private static readonly IFormatProvider InvariantFormat = System.Globalization.CultureInfo.InvariantCulture;

        /// <summary>ISO 8601 layout, which is unambiguous regardless of the server's date-order setting.</summary>
        private const string DateTimeLiteralFormat = "yyyy-MM-ddTHH:mm:ss.fff";

        private readonly ObjectContext _context;

        /// <summary>
        /// Creates the repository.
        /// </summary>
        /// <param name="context">Context whose metadata is used to look up entity primary keys.</param>
        protected RepositoryBase(ObjectContext context)
        {
            _context = context;
        }

        #region Method

        /// <summary>
        /// Builds the equality predicate "<c>[fieldName] = {paramId}</c>", where <c>{paramId}</c> is a
        /// composite-format placeholder that <see cref="DoSQL{TEntity}"/> fills in later.
        /// </summary>
        private static string GetEqualStatment(string fieldName, int paramId)
        {
            return QuoteIdentifier(fieldName) + " = " + GetParamTag(paramId);
        }

        /// <summary>
        /// Returns the composite-format placeholder "<c>{paramId}</c>".
        /// </summary>
        private static string GetParamTag(int paramId)
        {
            return "{" + paramId.ToString(InvariantFormat) + "}";
        }

        /// <summary>
        /// Wraps a table or column name in square brackets, doubling any <c>]</c>, so reserved
        /// words (e.g. <c>Order</c>, <c>User</c>) can be used as column names.
        /// </summary>
        private static string QuoteIdentifier(string name)
        {
            return "[" + name.Replace("]", "]]") + "]";
        }

        /// <summary>
        /// Wraps text in single quotes and doubles embedded quotes (standard SQL string escaping).
        /// </summary>
        private static string QuoteString(string text)
        {
            return "'" + text.Replace("'", "''") + "'";
        }

        /// <summary>
        /// Renders a property value as a SQL literal.
        /// </summary>
        /// <param name="value">The boxed property value.</param>
        /// <returns>
        /// <c>NULL</c> for null/<see cref="DBNull"/>; <c>1</c>/<c>0</c> for booleans; the underlying
        /// integer for enums; invariant-culture digits for numbers; a quoted ISO 8601 string for
        /// dates; a quoted, escaped string for everything else (string, char, Guid, TimeSpan, ...).
        /// </returns>
        private static string ToSqlLiteral(object value)
        {
            // Without this, an enum would format as its member name, which is not a valid SQL value.
            if (value is Enum)
                value = Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()), InvariantFormat);

            switch (Convert.GetTypeCode(value))
            {
                case TypeCode.Empty:
                case TypeCode.DBNull:
                    return "NULL";
                case TypeCode.Boolean:
                    return (bool)value ? "1" : "0";
                case TypeCode.SByte:
                case TypeCode.Byte:
                case TypeCode.Int16:
                case TypeCode.UInt16:
                case TypeCode.Int32:
                case TypeCode.UInt32:
                case TypeCode.Int64:
                case TypeCode.UInt64:
                case TypeCode.Decimal:
                    return ((IFormattable)value).ToString(null, InvariantFormat);
                case TypeCode.Single:
                case TypeCode.Double:
                    // "R" round-trips the exact binary value instead of rounding to 7/15 digits.
                    return ((IFormattable)value).ToString("R", InvariantFormat);
                case TypeCode.DateTime:
                    return QuoteString(((DateTime)value).ToString(DateTimeLiteralFormat, InvariantFormat));
                case TypeCode.String:
                    return QuoteString((string)value);
            }

            if (value is DateTimeOffset)
                return QuoteString(((DateTimeOffset)value).ToString(DateTimeLiteralFormat + "zzz", InvariantFormat));

            return QuoteString(Convert.ToString(value, InvariantFormat));
        }

        /// <summary>
        /// Decides whether a public property is written as a column. The property must be readable
        /// and non-indexed, and must be a value type or <see cref="string"/>. It must not be an
        /// <see cref="EntityState"/>, and must not be marked with <see cref="NavigationAttribute"/>.
        /// </summary>
        /// <remarks>
        /// Reference-typed properties (navigation references/collections, <see cref="EntityKey"/>)
        /// are rejected by type before their getter is invoked, so no lazy load is triggered.
        /// </remarks>
        private static bool IsScalarColumn(System.Reflection.PropertyInfo property)
        {
            Type propertyType = property.PropertyType;
            return property.CanRead
                && property.GetIndexParameters().Length == 0
                && (propertyType.IsValueType || propertyType == typeof(string))
                && propertyType != typeof(EntityState)
                && !property.IsDefined(typeof(NavigationAttribute), false);
        }

        /// <summary>
        /// Appends " WHERE [k1] = {n} AND [k2] = {n+1} ..." for every primary-key column. The
        /// corresponding key-value literals are added to <paramref name="arguments"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// A key has no CLR property on <paramref name="entityType"/>, or its value is null.
        /// </exception>
        private static void AppendPrimaryKeyFilter(StringBuilder builder, Type entityType, object entity,
            IList<string> pkList, List<object> arguments)
        {
            builder.Append(" WHERE ");
            for (int index = 0; index < pkList.Count; index++)
            {
                string keyName = pkList[index];
                if (index > 0)
                    builder.Append(" AND ");

                System.Reflection.PropertyInfo keyProperty = entityType.GetProperty(keyName);
                if (keyProperty == null)
                    throw new InvalidOperationException(string.Format(
                        "Entity type '{0}' has no property for primary key '{1}'.", entityType.Name, keyName));

                object keyValue = keyProperty.GetValue(entity, null);
                if (keyValue == null)
                    throw new InvalidOperationException(string.Format(
                        "Primary key '{0}' of entity type '{1}' is null.", keyName, entityType.Name));

                builder.Append(GetEqualStatment(keyName, arguments.Count));
                arguments.Add(ToSqlLiteral(keyValue));
            }
        }

        /// <summary>
        /// Gets the key members of the entity set whose element type is named like <typeparamref name="TEntity"/>.
        /// </summary>
        /// <typeparam name="TEntity">The entity CLR type.</typeparam>
        /// <returns>The key members (primary-key properties) of the entity type.</returns>
        /// <exception cref="InvalidOperationException">No entity set was found for <typeparamref name="TEntity"/>.</exception>
        protected ReadOnlyMetadataCollection<EdmMember> GetPrimaryKey<TEntity>()
        {
            EntitySetBase entitySet = _context.GetEntitySet(typeof(TEntity));
            if (entitySet == null)
                throw new InvalidOperationException(string.Format(
                    "No entity set was found for entity type '{0}'.", typeof(TEntity).Name));
            return entitySet.ElementType.KeyMembers;
        }

        /// <summary>
        /// Kind of SQL statement to build.
        /// </summary>
        protected enum SQLType
        {
            Insert,
            Update,
            Delete,
        }

        /// <summary>
        /// Builds an UPDATE statement. It sets every non-key scalar property with a non-null value
        /// and filters on all primary-key columns.
        /// </summary>
        /// <typeparam name="TEntity">Entity type used to look up the primary key.</typeparam>
        /// <param name="entity">The entity to update.</param>
        /// <returns>The statement with <c>{n}</c> placeholders, and the SQL literals that fill them.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        /// <exception cref="ArgumentException">The entity type has no primary key.</exception>
        /// <exception cref="InvalidOperationException">No property would be updated.</exception>
        private Tuple<string, object[]> CreateUpdateSQL<TEntity>(TEntity entity) where TEntity : class
        {
            if (entity == null)
                throw new ArgumentNullException("entity", "The database entity can not be null.");

            List<string> pkList = GetPrimaryKey<TEntity>().Select(i => i.Name).ToList();
            if (pkList.Count == 0)
                throw new ArgumentException("The Table entity have not a primary key.");

            // NOTE(review): the table name comes from the runtime type, while the key comes from TEntity.
            Type entityType = entity.GetType();
            List<object> arguments = new List<object>();
            StringBuilder builder = new StringBuilder();

            foreach (var property in entityType.GetProperties())
            {
                if (pkList.Contains(property.Name) || !IsScalarColumn(property))
                    continue;

                // Null values (e.g. empty Nullable<T>) are skipped rather than written as NULL.
                object value = property.GetValue(entity, null);
                if (value == null)
                    continue;

                if (arguments.Count != 0)
                    builder.Append(", ");
                builder.Append(GetEqualStatment(property.Name, arguments.Count));
                arguments.Add(ToSqlLiteral(value));
            }

            if (builder.Length == 0)
                throw new InvalidOperationException("没有任何属性进行更新");

            builder.Insert(0, " UPDATE " + QuoteIdentifier(entityType.Name) + " SET ");
            AppendPrimaryKeyFilter(builder, entityType, entity, pkList, arguments);
            builder.Append(";");
            return Tuple.Create(builder.ToString(), arguments.ToArray());
        }

        /// <summary>
        /// Builds a DELETE statement that filters on all primary-key columns.
        /// </summary>
        /// <typeparam name="TEntity">Entity type used to look up the primary key.</typeparam>
        /// <param name="entity">The entity to delete.</param>
        /// <returns>The statement with <c>{n}</c> placeholders, and the SQL literals that fill them.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        /// <exception cref="ArgumentException">The entity type has no primary key.</exception>
        private Tuple<string, object[]> CreateDeleteSQL<TEntity>(TEntity entity) where TEntity : class
        {
            if (entity == null)
                throw new ArgumentNullException("entity", "The database entity can not be null.");

            List<string> pkList = GetPrimaryKey<TEntity>().Select(i => i.Name).ToList();
            if (pkList.Count == 0)
                throw new ArgumentException("The Table entity have not a primary key.");

            Type entityType = entity.GetType();
            List<object> arguments = new List<object>();
            StringBuilder builder = new StringBuilder();
            builder.Append(" Delete from " + QuoteIdentifier(entityType.Name));
            AppendPrimaryKeyFilter(builder, entityType, entity, pkList, arguments);
            builder.Append(";");
            return Tuple.Create(builder.ToString(), arguments.ToArray());
        }

        /// <summary>
        /// Builds an INSERT statement from every scalar property with a non-null value.
        /// A primary-key property whose value is 0 is omitted, so an identity column can be
        /// generated by the database.
        /// </summary>
        /// <typeparam name="TEntity">Entity type used to look up the primary key.</typeparam>
        /// <param name="entity">The entity to insert.</param>
        /// <returns>The statement with <c>{n}</c> placeholders, and the SQL literals that fill them.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        private Tuple<string, object[]> CreateInsertSQL<TEntity>(TEntity entity) where TEntity : class
        {
            if (entity == null)
                throw new ArgumentNullException("entity", "The database entity can not be null.");

            Type entityType = entity.GetType();
            string tableName = QuoteIdentifier(entityType.Name);
            List<string> pkList = GetPrimaryKey<TEntity>().Select(i => i.Name).ToList();
            List<object> arguments = new List<object>();
            StringBuilder fieldbuilder = new StringBuilder();
            StringBuilder valuebuilder = new StringBuilder();

            foreach (var property in entityType.GetProperties())
            {
                // NOTE(review): "IsValid" is excluded from INSERT only (not UPDATE); intent not visible here.
                if (property.Name == "IsValid" || !IsScalarColumn(property))
                    continue;

                object value = property.GetValue(entity, null);
                if (value == null)
                    continue;
                if (pkList.Contains(property.Name) && Convert.ToString(value, InvariantFormat) == "0")
                    continue;

                if (arguments.Count != 0)
                {
                    fieldbuilder.Append(", ");
                    valuebuilder.Append(", ");
                }
                fieldbuilder.Append(QuoteIdentifier(property.Name));
                valuebuilder.Append(GetParamTag(arguments.Count));
                arguments.Add(ToSqlLiteral(value));
            }

            // "INSERT INTO t () VALUES ()" is invalid SQL; fall back to the column defaults.
            if (arguments.Count == 0)
                return Tuple.Create(" INSERT INTO " + tableName + " DEFAULT VALUES;", new object[0]);

            string sql = " INSERT INTO " + tableName + " (" + fieldbuilder + ") Values (" + valuebuilder + ");";
            return Tuple.Create(sql, arguments.ToArray());
        }

        /// <summary>
        /// Builds (does not execute) one SQL statement of the given kind for each entity.
        /// It returns them concatenated into a single batch.
        /// </summary>
        /// <typeparam name="TEntity">Entity type.</typeparam>
        /// <param name="list">The entities to process.</param>
        /// <param name="sqlType">The kind of statement to build.</param>
        /// <returns>The concatenated SQL text.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="sqlType"/> is not a defined value.</exception>
        protected string DoSQL<TEntity>(IEnumerable<TEntity> list, SQLType sqlType) where TEntity : class
        {
            if (list == null)
                throw new ArgumentNullException("list");

            Func<TEntity, Tuple<string, object[]>> build;
            switch (sqlType)
            {
                case SQLType.Insert:
                    build = CreateInsertSQL<TEntity>;
                    break;
                case SQLType.Update:
                    build = CreateUpdateSQL<TEntity>;
                    break;
                case SQLType.Delete:
                    build = CreateDeleteSQL<TEntity>;
                    break;
                default:
                    throw new ArgumentException("请输入正确的参数");
            }

            StringBuilder sqlstr = new StringBuilder();
            foreach (TEntity item in list)
            {
                Tuple<string, object[]> sql = build(item);
                // The arguments are already-rendered SQL literals; this only substitutes the {n} tags.
                sqlstr.AppendFormat(InvariantFormat, sql.Item1, sql.Item2);
            }
            return sqlstr.ToString();
        }

        #endregion
    }
    /// <summary>
    /// Extension methods for <see cref="ObjectContext"/>.
    /// </summary>
    public static class ObjectContextExtensions
    {
        /// <summary>
        /// Finds the conceptual-model (CSpace) entity set in the context's default entity container
        /// whose element type has the same name as <paramref name="entityType"/>.
        /// </summary>
        /// <param name="context">The object context whose metadata is searched.</param>
        /// <param name="entityType">
        /// The CLR entity type. An EF dynamic proxy type is first mapped back to its entity type.
        /// </param>
        /// <returns>
        /// The matching entity set, or <c>null</c> if the default container or a matching set is not found.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="context"/> or <paramref name="entityType"/> is null.
        /// </exception>
        public static EntitySetBase GetEntitySet(this ObjectContext context, Type entityType)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }

            if (entityType == null)
            {
                throw new ArgumentNullException("entityType");
            }

            // GetEntityContainer throws when the container is missing; the Try form lets us return null instead.
            EntityContainer container;
            if (!context.MetadataWorkspace.TryGetEntityContainer(context.DefaultContainerName, DataSpace.CSpace, out container))
            {
                return null;
            }

            // A proxy type has a generated name, which would never match the model's type name.
            string typeName = ObjectContext.GetObjectType(entityType).Name;

            return container.BaseEntitySets.FirstOrDefault(
                set => string.Equals(set.ElementType.Name, typeName, StringComparison.Ordinal));
        }
    }


}
